Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement fixes for rstar #52

Merged
merged 24 commits into from
Dec 14, 2022
Merged

Implement fixes for rstar #52

merged 24 commits into from
Dec 14, 2022

Conversation

sethaxen
Copy link
Member

@sethaxen sethaxen commented Dec 13, 2022

Fixes #51:

  • uses a 70-30 split for training/test data
  • adds a keyword argument nsplit=2 split_chains=2 (as suggested in Add rank-normalized ESS and other variants #22 (comment)) to control how many chains each individual chain is split into (nsplit=1 split_chains=1 is the old behavior) to check for within-chain convergence.
  • uses stratified sampling to split, which ensures that frac draws for each chain are in the training data

I'm not thrilled with the name nsplit, as it's not terribly descriptive, but I haven't thought of a better one that wasn't quite verbose.

As suggested in #51 (comment), we consider these changes non-breaking because they make the defaults consistent with the recommendations in the paper.

src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
test/utils.jl Outdated Show resolved Hide resolved
@devmotion
Copy link
Member

I'm not thrilled with the name nsplit, as it's not terribly descriptive, but I haven't thought of a better one that wasn't quite verbose.

Maybe we could use even just split? Or split_chains?

@sethaxen
Copy link
Member Author

Maybe we could use even just split? Or split_chains?

What I don't like about these is that if the user forgets split or split_chains is a number of splits, they might set split_chains=true thinking its boolean, and accidentally disable splitting. At least nsplit conveys that it's the number of splits. Something like num_chain_splits is a little ugly but maybe less prone to mistakes?

@coveralls
Copy link

coveralls commented Dec 13, 2022

Pull Request Test Coverage Report for Build 3686560314

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 55 of 55 (100.0%) changed or added relevant lines in 2 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.4%) to 94.903%

Totals Coverage Status
Change from base Build 3680569701: 0.4%
Covered Lines: 633
Relevant Lines: 667

💛 - Coveralls

@codecov
Copy link

codecov bot commented Dec 13, 2022

Codecov Report

Base: 94.50% // Head: 94.91% // Increases project coverage by +0.41% 🎉

Coverage data is based on head (3a288fa) compared to base (8d74357).
Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #52      +/-   ##
==========================================
+ Coverage   94.50%   94.91%   +0.41%     
==========================================
  Files           9       10       +1     
  Lines         619      669      +50     
==========================================
+ Hits          585      635      +50     
  Misses         34       34              
Impacted Files Coverage Δ
src/rstar.jl 100.00% <100.00%> (ø)
src/utils.jl 100.00% <100.00%> (ø)

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

@devmotion
Copy link
Member

What I don't like about these is that if the user forgets split or split_chains is a number of splits, they might set split_chains=true thinking its boolean, and accidentally disable splitting. At least nsplit conveys that it's the number of splits. Something like num_chain_splits is a little ugly but maybe less prone to mistakes?

If we enforce split::Int the function will error if someone sets split=true:

julia> f(; split::Int=2) = split
f (generic function with 1 method)

julia> f()
2

julia> f(; split=4)
4

julia> f(; split=true)
ERROR: TypeError: in keyword argument split, expected Int64, got a value of type Bool
Stacktrace:
 [1] top-level scope
   @ REPL[4]:1

@sethaxen
Copy link
Member Author

Ah, of course! Then in that case, I prefer split_chains. I'll update.

src/utils.jl Outdated
Comment on lines 10 to 14
if haskey(d, xi)
push!(d[xi], i)
else
d[xi] = [i]
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be made more efficient by not looking up the key twice. One could e.g. use

Suggested change
if haskey(d, xi)
push!(d[xi], i)
else
d[xi] = [i]
end
d_xi = get!(d, xi) do
return Int[]
end
push!(d_xi, i)

Apart from that, it seems like a function that could exist e.g. in StatsBase (similar to proportionmap etc.). Did you check that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this would fit in StatsBase, but there currently is no such method (indexmap is closest). MLUtils has group_indices, which is equivalent, but the dependency is too heavy.

I found a few threads of people looking for this, e.g. https://discourse.julialang.org/t/is-there-a-function-similar-to-numpy-unique-with-inverse/80949 but with no clear answer.

An alternative would be to stick closer to NumPy's very useful return_inverse=True approach and return 2 vectors, basically the sorted keys and corresponding values. Either way, this could later be upstreamed to StatsBase.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I just want to make sure we use existing functionality. If it doesn't exist yet that's unfortunate but, of course, then we should use our own implementation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it seems indicatormap returns the information we are interested in: https://juliastats.org/StatsBase.jl/stable/misc/#StatsBase.indicatormat But maybe it's not the desired output format for our purposes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's similar, yes, but a little clunky. e.g. here's how we could get the vector of indices:

using SparseArrays
map(first  findnz  sparse, eachslice(indicatormat(x; sparse=true); dims=1))

But I still think it makes more sense to try to upstream the functionality we want, since often something like what we want will be more convenient for the user.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree. That seems a bit inconvenient.

src/utils.jl Outdated Show resolved Hide resolved
end

"""
split_chain_indices(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't there some existing splitting functionality for ess? Is the plan to merge these eventually?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not quite merge, because there are two different types of splitting we can consider. This approach supports ragged chains and is as a result more complex and doesn't discard any draws (instead dividing the remainder across the earlier splits).

For ess/rhat, we don't support ragged chains so would discard draws if necessary to keep them the same length after splitting. This implementation is much simpler and can be done in a non-allocating way with just reshape and view on a 3d array. This will be part of #22.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing splitting functionality copy_split! will go away.

src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
@sethaxen
Copy link
Member Author

@devmotion I have implemented all of your suggestions

The failed integration test with MCMCChains is expected. That particular failure checks whether two chains, each of constant value can be perfectly discriminated by the classifier. With the new default of split_chains=2, perfect discrimination is no longer possible, so R* is not 2. That test can be updated to use split_chains=1 after this PR is merged and a new version is released.

src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
sethaxen and others added 2 commits December 13, 2022 21:49
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, thank you!

@sethaxen
Copy link
Member Author

Looks good to me, thank you!

No problem, thanks for the detailed reviews!

@sethaxen
Copy link
Member Author

Ah, it seems I do not have permissions to merge, since the integration test fails. Can you merge?

@devmotion devmotion merged commit 63e82cf into main Dec 14, 2022
@delete-merged-branch delete-merged-branch bot deleted the rstarfixes branch December 14, 2022 01:15
@coveralls
Copy link

coveralls commented Oct 11, 2024

Pull Request Test Coverage Report for Build 3689690844

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 57 of 57 (100.0%) changed or added relevant lines in 2 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.4%) to 94.918%

Totals Coverage Status
Change from base Build 3680569701: 0.4%
Covered Lines: 635
Relevant Lines: 669

💛 - Coveralls

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Proposed changes to rstar defaults
3 participants